from typing import List, Callable, Dict, Optional, Union
import warnings
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from transformers.optimization import get_scheduler
from torch.nn.utils import clip_grad_norm_
from torch.distributed import get_rank, is_initialized


from utils import save_model_and_config
from dataset.cyclone import CycloneSample
from train.integrals import FluxIntegral
from eval.complex_metrics import ComplexMetrics


# TODO(diff): unused for diffusion models; use the NRMSE metric from The Well instead:
# https://polymathic-ai.org/the_well/api/#the_well.benchmark.metrics.NRMSE
def relative_norm_mse(x, y, dim_to_keep=None, squared=True):
    """Relative L2 error of ``x`` with respect to the reference ``y``.

    For inputs with more than one dimension, the trailing dimensions are
    flattened before taking L2 norms. Only dim 0 is kept when ``dim_to_keep``
    is None; dims 0 and 1 are kept otherwise ("inference mode"). The per-row
    ratio ``||x - y|| / (||y|| + eps)`` is then formed, with both norms
    squared first when ``squared`` is True. The ratio is averaged over
    everything when ``dim_to_keep`` is None, or over the dimensions after
    ``dim_to_keep`` otherwise.
    """
    assert x.shape == y.shape, "Mismatch in dimensions for computing loss"
    eps = 1e-8

    if x.ndim > 1:
        first_flat_dim = 1 if dim_to_keep is None else 2
        x = x.flatten(first_flat_dim)
        y = y.flatten(first_flat_dim)

    err_norm = torch.linalg.norm(x - y, ord=2, dim=-1)
    ref_norm = torch.linalg.norm(y, ord=2, dim=-1)
    if squared:
        err_norm = err_norm**2
        ref_norm = ref_norm**2

    ratio = err_norm / (ref_norm + eps)
    if dim_to_keep is None:
        # mean over every remaining entry (examples and any leading dims)
        return torch.mean(ratio)

    reduce_dims = list(range(ref_norm.ndim))[dim_to_keep + 1 :]
    return torch.mean(ratio, dim=reduce_dims)


class LossWrapper(nn.Module):
    def __init__(
        self,
        weights: Dict,
        schedulers: Dict,
        denormalize_fn: Optional[Callable] = None,
        separate_zf: bool = False,
        real_potens: bool = False,
        loss_type: str = "mse",
        integral_loss_type: str = "mse",
        spectral_loss_type: str = "l1",
        dataset_stats: Optional[Dict] = None,
        ds: Optional[float] = None,
        ema_normalization_loss: Optional[List[str]] = None,
        ema_beta: float = 0.99,
        eval_loss_type: str = "mse",
        eval_integral_loss_type: str = "mse",
        eval_spectral_loss_type: str = "l1",
    ):
        """Set up loss bookkeeping, integrators and loss-type configuration.

        Args:
            weights: Per-loss weights keyed by loss name. The dict is kept by
                reference, so scheduler updates made during training are
                visible to the caller.
            schedulers: Loss name -> callable(progress_remaining) giving the
                current weight of that loss during training.
            denormalize_fn: Called as ``denormalize_fn(file_index, df=...)``
                (or ``phi=`` / ``flux=``). Required by the integral losses.
            separate_zf: When True, the integral losses merge extra df
                channels (zonal-flow components) before integrating.
            real_potens: Forwarded to both ``FluxIntegral`` instances.
            loss_type, integral_loss_type, spectral_loss_type: Loss types
                used in training mode.
            dataset_stats: Optional precomputed statistics (e.g. "flux_std",
                "phi_std") used by the "int_norm_*" integral losses.
            ds: Scale factor for the spectral diagnostics. Spectral losses are
                skipped when this is None.
            ema_normalization_loss: Loss names whose scale is tracked by an EMA.
            ema_beta: EMA decay for those scales.
            eval_loss_type, eval_integral_loss_type, eval_spectral_loss_type:
                Loss types used in evaluation mode.
        """
        super().__init__()
        self.weights = weights
        self._data_losses = ["df", "phi", "flux"]
        self._int_losses = ["flux_int", "phi_int"]
        # Add VAE and VQ-VAE losses
        self._vae_losses = ["kl_div"]
        self._vqvae_losses = ["vq_commit"]
        # Add spectral losses
        self._spectral_losses = [
            "kxspec",
            "kyspec",
            "qspec",
            "phi_zf",
            "kxspec_monotonicity",
            "kyspec_monotonicity",
            "qspec_monotonicity",
            "mass",
        ]

        # for integral losses
        self.integrator = FluxIntegral(
            real_potens=real_potens,
            flux_fields=False,  # return scalar integrals
            spectral_df=False,
        )

        # for spectral losses
        self.integrator_spec = FluxIntegral(
            real_potens=real_potens,
            flux_fields=True,  # get full flux fields for spectral diagnostics
            spectral_df=True,
        )
        self.denormalize_fn = denormalize_fn
        self.separate_zf = separate_zf
        self.schedulers = schedulers
        self.loss_type = loss_type
        self.integral_loss_type = integral_loss_type
        self.spectral_loss_type = spectral_loss_type

        # Separate loss types for evaluation (always consistent)
        self.eval_loss_type = eval_loss_type
        self.eval_integral_loss_type = eval_integral_loss_type
        self.eval_spectral_loss_type = eval_spectral_loss_type

        # use precomputed dataset statistics for normalization
        self.dataset_stats = dataset_stats or {}

        # ds for spectral diagnostics
        self.ds = ds

        # Loss normalization for gradient balancing
        self.loss_normalizer = {}
        # NOTE(review): loss_type is a str, so this getattr always yields False.
        # Kept for compatibility; confirm the intended source (likely a config).
        self.normalize_losses = getattr(loss_type, "normalize_losses", False)

        # EMA normalization configuration
        self.ema_normalization_loss = ema_normalization_loss or []
        self.ema_beta = ema_beta
        self._ema_loss_scales = {}  # Store running averages of loss scales
        self._ema_initialized = set()  # Track which losses have been initialized

        try:
            from eval.complex_metrics import ComplexMetrics

            self.complex_metrics = ComplexMetrics()
        except ImportError:
            self.complex_metrics = None
            # Only the "complex_*" loss types need ComplexMetrics. Leave every
            # other configured type alone, but downgrade complex ones in BOTH
            # modes; a complex eval type would otherwise fail on
            # ``None.to_complex`` during evaluation.
            fell_back = False
            if str(self.loss_type).startswith("complex_"):
                self.loss_type = "mse"
                fell_back = True
            if str(self.eval_loss_type).startswith("complex_"):
                self.eval_loss_type = "mse"
                fell_back = True
            if fell_back:
                warnings.warn("ComplexMetrics not available, falling back to MSE")

    def _update_ema_loss_scale(self, loss_name: str, loss_value: torch.Tensor):
        """Update EMA of loss scale for normalization."""
        if loss_name not in self.ema_normalization_loss:
            return

        current_scale = loss_value.detach().item()

        if loss_name not in self._ema_initialized:
            self._ema_loss_scales[loss_name] = current_scale
            self._ema_initialized.add(loss_name)
        else:
            # update EMA: ema = beta * ema + (1 - beta) * current
            self._ema_loss_scales[loss_name] = (
                self.ema_beta * self._ema_loss_scales[loss_name]
                + (1 - self.ema_beta) * current_scale
            )

    def _apply_ema_normalization(
        self, losses: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Apply EMA-based normalization to specified losses."""
        normalized_losses = {}

        for loss_name, loss_value in losses.items():
            if (
                loss_name in self.ema_normalization_loss
                and loss_name in self._ema_loss_scales
            ):
                # Apply normalization: loss / ema_scale
                ema_scale = self._ema_loss_scales[loss_name]
                if ema_scale > 1e-8:  # Avoid division by near-zero
                    normalized_losses[loss_name] = loss_value / ema_scale
                else:
                    normalized_losses[loss_name] = loss_value
            else:
                normalized_losses[loss_name] = loss_value

        return normalized_losses

    def _get_current_loss_types(self):
        """Get the appropriate loss types based on training vs evaluation mode."""
        if self.training:
            return {
                "data_loss_type": self.loss_type,
                "integral_loss_type": self.integral_loss_type,
                "spectral_loss_type": self.spectral_loss_type,
            }
        else:
            return {
                "data_loss_type": self.eval_loss_type,
                "integral_loss_type": self.eval_integral_loss_type,
                "spectral_loss_type": self.eval_spectral_loss_type,
            }

    def compute_data_loss_with_type(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        loss_type: str,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        """Compute loss with specified loss type."""
        if loss_type == "mse":
            return F.mse_loss(pred, target)
        elif loss_type == "l1":
            return F.l1_loss(pred, target)
        elif loss_type == "huber":
            return F.huber_loss(pred, target)
        elif loss_type == "smooth_l1":
            return F.smooth_l1_loss(pred, target)
        elif loss_type == "complex_mse":
            pred_complex = self.complex_metrics.to_complex(pred)
            target_complex = self.complex_metrics.to_complex(target)
            return self.complex_metrics.complex_mse(pred_complex, target_complex).mean()
        elif loss_type == "complex_l1":
            pred_complex = self.complex_metrics.to_complex(pred)
            target_complex = self.complex_metrics.to_complex(target)
            return self.complex_metrics.complex_l1(pred_complex, target_complex).mean()
        elif loss_type == "relative_mse":
            # relative MSE loss (normalized by target magnitude)
            relative_error = (pred - target) / (torch.abs(target) + eps)
            return (relative_error**2).mean()
        elif loss_type == "relative_l1":
            # relative L1 loss (normalized by target magnitude)
            relative_error = torch.abs(pred - target) / (torch.abs(target) + eps)
            return relative_error.mean()
        elif loss_type == "log_error":
            # log error loss - robust to orders of magnitude differences
            pred_log = torch.log(torch.abs(pred) + eps)
            target_log = torch.log(torch.abs(target) + eps)
            return F.mse_loss(pred_log, target_log)
        elif loss_type == "log_l1_error":
            # log L1 error loss
            pred_log = torch.log(torch.abs(pred) + eps)
            target_log = torch.log(torch.abs(target) + eps)
            return F.l1_loss(pred_log, target_log)
        elif loss_type == "log_cosh":
            # log-cosh loss - smooth version of absolute loss
            return torch.log(torch.cosh(pred - target)).mean()
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")

    def get_ema_statistics(self) -> Dict[str, torch.Tensor]:
        """Get current EMA loss scales for logging/debugging."""
        return {
            f"ema_scale_{loss_name}": torch.tensor(scale, dtype=torch.float32)
            for loss_name, scale in self._ema_loss_scales.items()
        }

    def compute_data_loss(
        self, pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-8
    ) -> torch.Tensor:
        """Compute loss based on the current mode (training vs evaluation)."""
        current_loss_types = self._get_current_loss_types()
        return self.compute_data_loss_with_type(
            pred, target, current_loss_types["data_loss_type"], eps
        )

    @property
    def all_losses(self):
        """Return all possible loss keys for this wrapper."""
        return (
            self._data_losses
            + self._int_losses
            + self._vae_losses
            + self._vqvae_losses
            + self._spectral_losses
        )

    def compute_integral_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        loss_type: str = "mse",
        eps: float = 1e-8,
        loss_name: str = "flux_int",
    ) -> torch.Tensor:
        """
        Compute integral-specific losses with robust scaling for physics quantities.
        Designed for flux integrals and potential integrals that can vary by orders of magnitude.
        """
        if loss_type == "mse":
            return F.mse_loss(pred, target)
        elif loss_type == "relative_mse":
            # Relative MSE - robust to magnitude changes during non-steady state
            relative_error = (pred - target) / (torch.abs(target) + eps)
            return (relative_error**2).mean()
        elif loss_type == "relative_l1":
            # Relative L1 - more robust to outliers
            relative_error = torch.abs(pred - target) / (torch.abs(target) + eps)
            return relative_error.mean()
        elif loss_type == "log_error":
            # Log error - for values spanning multiple orders of magnitude
            pred_log = torch.log(torch.abs(pred) + eps)
            target_log = torch.log(torch.abs(target) + eps)
            return F.mse_loss(pred_log, target_log)
        elif loss_type == "adaptive_relative":
            # relative loss with running normalization
            alpha = 0.01  # EMA decay factor for the normalization
            if loss_name == "flux_int":
                if not hasattr(self, "_target_ema_flux"):
                    self._target_ema_flux = torch.abs(target).mean().item()
                else:
                    self._target_ema_flux = (
                        alpha * torch.abs(target).mean().item()
                        + (1 - alpha) * self._target_ema_flux
                    )
                scale = max(self._target_ema_flux, eps)

            if loss_name == "phi_int":
                if not hasattr(self, "_target_ema_phi"):
                    self._target_ema_phi = torch.abs(target).mean().item()
                else:
                    self._target_ema_phi = (
                        alpha * torch.abs(target).mean().item()
                        + (1 - alpha) * self._target_ema_phi
                    )
                scale = max(self._target_ema_phi, eps)

            normalized_error = (pred - target) / scale
            return (normalized_error**2).mean()
        elif loss_type == "int_norm_mse":
            # use stats from precomputed dataset statistics
            if loss_name == "flux_int" and "flux_std" in self.dataset_stats:
                scale = max(float(self.dataset_stats["flux_std"]), eps)
                normalized_error = (pred - target) / scale
                return (normalized_error**2).mean()
            elif loss_name == "phi_int" and "phi_std" in self.dataset_stats:
                scale = max(float(self.dataset_stats["phi_std"]), eps)
                normalized_error = (pred - target) / scale
                return (normalized_error**2).mean()
            else:
                # fallback to relative loss if we have no stats
                relative_error = (pred - target) / (torch.abs(target) + eps)
                return (relative_error**2).mean()
        elif loss_type == "int_norm_l1":
            # use stats from precomputed dataset statistics
            if loss_name == "flux_int" and "flux_std" in self.dataset_stats:
                scale = max(float(self.dataset_stats["flux_std"]), eps)
                normalized_error = torch.abs(pred - target) / scale
                return normalized_error.mean()
            elif loss_name == "phi_int" and "phi_std" in self.dataset_stats:
                scale = max(float(self.dataset_stats["phi_std"]), eps)
                normalized_error = torch.abs(pred - target) / scale
                return normalized_error.mean()
            else:
                # fallback to relative loss if we have no stats
                relative_error = torch.abs(pred - target) / (torch.abs(target) + eps)
                return relative_error.mean()
        else:
            raise ValueError(f"Unknown integral loss type: {loss_type}")

    def compute_vae_loss(
        self, preds: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Compute KL divergence loss for VAE"""
        if "mu" not in preds or "logvar" not in preds:
            return {}

        mu = preds["mu"]
        logvar = preds["logvar"]
        # KL divergence: KL(q(z|x) || p(z)) where p(z) = N(0,I)
        # KL = -0.5 * mean/sum(1 + log(σ²) - μ² - σ²)?
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

        return {"kl_div": kl_loss}

    def compute_vqvae_loss(
        self, preds: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        vqvae_losses = {
            "vq_commit": preds["vq_commit_loss"] if "vq_commit_loss" in preds else None,
        }
        return vqvae_losses

    def get_loss_statistics(
        self, losses: Dict[str, torch.Tensor]
    ) -> Dict[str, Dict[str, float]]:
        """
        Compute statistics for each loss to help diagnose scale differences.
        Useful for understanding which losses dominate and need balancing.
        """
        stats = {}
        for name, loss_val in losses.items():
            if isinstance(loss_val, torch.Tensor):
                val = loss_val.item()
                if torch.isnan(loss_val) or not torch.isfinite(loss_val):
                    # Handle NaN or infinite values
                    stats[name] = {
                        "value": val,
                        "log10_value": float("nan"),
                        "magnitude_order": None,
                    }
                else:
                    # Normal case with finite values
                    log10_val = torch.log10(torch.clamp(loss_val, min=1e-12)).item()
                    stats[name] = {
                        "value": val,
                        "log10_value": log10_val,
                        "magnitude_order": int(log10_val),
                    }
        return stats

    def integral_loss(
        self,
        geometry: Dict[str, torch.Tensor],
        preds: Dict[str, torch.Tensor],
        tgts: Dict[str, torch.Tensor],
        idx_data: Optional[Dict[str, torch.Tensor]] = None,
        integral_loss_type: str = "mse",
    ):
        """Compute phi/flux integral losses from the predicted distribution function.

        In training mode, predictions and targets are first denormalized sample
        by sample with ``self.denormalize_fn``, keyed by
        ``idx_data["file_index"]``. In eval mode they are assumed to be
        denormalized already. The predicted df is integrated with
        ``self.integrator``, and the result is compared against the target phi
        and energy flux.

        NOTE: in training mode a 3-D ``preds["phi"]`` / ``tgts["phi"]`` is
        replaced IN PLACE in the caller's dict by its ``unsqueeze(0)`` version.

        Args:
            geometry: Geometry passed through to ``self.integrator``.
            preds: Predictions. Must contain "df" and may contain "phi".
            tgts: Targets. Must contain "phi" and "flux".
            idx_data: Required in training mode (its "file_index" is read).
            integral_loss_type: Loss used for the returned ``int_losses``.
                Unknown values warn and fall back to the monitoring losses.

        Returns:
            Tuple ``(int_losses, int_losses_for_monitoring, integrated)``:
            ``int_losses`` has keys "phi_int" / "flux_int", the monitoring dict
            has "phi_int_mse" / "flux_int_mse", and ``integrated`` holds the
            integrated "phi", "pflux" and "eflux".
        """
        assert self.denormalize_fn is not None
        assert geometry is not None
        if self.training:
            pred_df = []
            pred_phi = []
            tgt_phi = []
            tgt_eflux = []
            # denormalize each sample with the statistics of its source file
            for b, f in enumerate(idx_data["file_index"].tolist()):
                assert "df" in preds, "Integral losses requires df (5D)."
                pred_df.append(self.denormalize_fn(f, df=preds["df"][b]))
                if "phi" in preds:
                    # a 3-D phi gets a leading dim (written back into preds)
                    if preds["phi"].ndim == 3:
                        preds["phi"] = preds["phi"].unsqueeze(0)
                    pred_phi.append(self.denormalize_fn(f, phi=preds["phi"][b]))
                # same for the target phi (written back into tgts)
                if tgts["phi"].ndim == 3:
                    tgts["phi"] = tgts["phi"].unsqueeze(0)
                tgt_phi.append(self.denormalize_fn(f, phi=tgts["phi"][b]))
                tgt_eflux.append(self.denormalize_fn(f, flux=tgts["flux"][b]))
            pred_df = torch.stack(pred_df)
            # no predicted phi -> None is handed to the integrator
            if len(pred_phi) > 0:
                pred_phi = torch.stack(pred_phi)
            else:
                pred_phi = None
            tgt_phi = torch.stack(tgt_phi)
            tgt_eflux = torch.stack(tgt_eflux)
        else:
            # already denormalized for evaluation
            pred_df = preds["df"]
            pred_phi = preds["phi"] if "phi" in preds else None
            tgt_phi = tgts["phi"]
            tgt_eflux = tgts["flux"]

        if self.separate_zf and pred_df.shape[1] > 2:
            # merge zonal flow components: 4 channels -> pairwise sums of
            # (0, 2) and (1, 3); otherwise sum even and odd channels separately
            if pred_df.shape[1] == 4:
                pred_df = pred_df[:, [0, 1]] + pred_df[:, [2, 3]]
            else:
                pred_df = torch.cat(
                    [pred_df[:, 0::2].sum(1, True), pred_df[:, 1::2].sum(1, True)],
                    dim=1,
                )

        pphi_int, (pflux, eflux, _) = self.integrator(geometry, pred_df, pred_phi)
        int_losses = {}
        int_losses_for_monitoring = {}

        # always monitor phi_int and flux_int loss for monitoring or backprop
        int_losses_for_monitoring["phi_int_mse"] = F.mse_loss(pphi_int, tgt_phi)
        # NOTE: despite the "_mse" key, this is mean(|pflux|) + L1(eflux, tgt_eflux)
        int_losses_for_monitoring["flux_int_mse"] = torch.abs(pflux).mean() + F.l1_loss(
            eflux, tgt_eflux
        )

        # calc loss
        if integral_loss_type in [
            "relative_mse",
            "relative_l1",
            "log_error",
            "adaptive_relative",
            "int_norm_mse",
            "int_norm_l1",
        ]:
            int_losses["phi_int"] = self.compute_integral_loss(
                pphi_int, tgt_phi, integral_loss_type, loss_name="phi_int"
            )
            flux_loss_eflux = self.compute_integral_loss(
                eflux, tgt_eflux, integral_loss_type, loss_name="flux_int"
            )
            # the |pflux| penalty is always applied as a plain mean
            flux_loss_pflux = torch.abs(pflux).mean()
            int_losses["flux_int"] = flux_loss_pflux + flux_loss_eflux
        elif integral_loss_type == "mse":
            # use the already calculated MSE
            int_losses["flux_int"] = int_losses_for_monitoring["flux_int_mse"]
            int_losses["phi_int"] = int_losses_for_monitoring["phi_int_mse"]
        else:
            warnings.warn(
                f"Unknown integral loss type '{integral_loss_type}', using mse."
            )
            int_losses["phi_int"] = int_losses_for_monitoring["phi_int_mse"]
            int_losses["flux_int"] = int_losses_for_monitoring["flux_int_mse"]

        return (
            int_losses,
            int_losses_for_monitoring,
            {"phi": pphi_int, "pflux": pflux, "eflux": eflux},
        )

    def phi_fft(self, phi: torch.Tensor, norm: str = "forward"):
        """Convert phi to FFT domain for spectral analysis."""
        # work with float32 to avoid cuFFT half-precision error
        phi = phi.float()

        # convert to complex number
        if phi.shape[1] == 2:  # [real, imaginary] channels
            phi_complex = torch.view_as_complex(phi.permute(0, 2, 3, 4, 1).contiguous())
        else:
            # already complex or real tensor
            phi_complex = phi.squeeze(1).to(torch.complex64)

        # FFT + shift
        phi_fft = torch.fft.fftn(
            phi_complex, dim=(1, 3), norm=norm
        )  # (batch, x, s, y) -> FFT on (x, y)
        phi_fft = torch.fft.fftshift(phi_fft, dim=(1,))  # Shift on x dimension
        return phi_fft

    def diagnostics(
        self,
        phi_fft: torch.Tensor,
        eflux_field: torch.Tensor,
        ds: float,
        zf_mode: int = 0,
        aggregate: str = "mean",
    ):
        """Compute spectral diagnostics from phi_fft and eflux_field."""
        diag = {}

        # batch dimension - phi_fft shape: (batch, nx, ns, ny)
        batch_size = phi_fft.shape[0]
        nx, ns, ny = phi_fft.shape[1:]

        # Compute kxspec - sum over (s, y) dims, scale by ds
        kxspec = torch.sum(torch.abs(phi_fft) ** 2, dim=(2, 3)) * ds  # (batch, nx)
        if aggregate == "mean":
            diag["kxspec"] = torch.sum(kxspec, dim=1)  # (batch,)
        elif aggregate == "mid":
            diag["kxspec"] = kxspec[:, kxspec.shape[1] // 2]  # (batch,)
        else:  # aggregate == "none"
            diag["kxspec"] = kxspec

        # Compute kyspec - sum over (s, x) dims, scale by ds
        kyspec = torch.sum(torch.abs(phi_fft) ** 2, dim=(1, 2)) * ds  # (batch, ny)
        if aggregate == "mean":
            diag["kyspec"] = torch.sum(kyspec, dim=1)  # (batch,)
        elif aggregate == "mid":
            diag["kyspec"] = kyspec[:, kyspec.shape[1] // 2]  # (batch,)
        else:  # aggregate == "none"
            diag["kyspec"] = kyspec  # Keep full spectrum (batch, ny)

        # TODO: check if zonal flow is used
        # Compute zf profile from phi_fft
        fourier_zf = phi_fft.clone()
        # mask everything except the zf_mode
        fourier_zf[..., :zf_mode] = 0.0
        fourier_zf[..., zf_mode + 1 :] = 0.0
        fourier_zf = torch.fft.fftshift(fourier_zf, dim=(1,))  # Shift x dimension
        diag["phi_zf"] = torch.fft.irfftn(
            fourier_zf, dim=(1, 3), norm="forward", s=[nx, ny]
        )  # (batch, nx, ns, ny)

        # Compute flux spectrum - sum over velocity space dimensions
        # eflux_field shape should be (batch, vpar, vmu, s, x, y)
        if eflux_field.dim() == 6:  # (batch, vpar, vmu, s, x, y)
            diag["qspec"] = eflux_field.sum(
                (1, 2, 3, 4)
            )  # (batch, ny) - keep full spectrum
        elif eflux_field.dim() == 5:  # (vpar, vmu, s, x, y) - no batch
            diag["qspec"] = eflux_field.sum((0, 1, 2, 3))  # (ny,)
            diag["qspec"] = diag["qspec"].unsqueeze(0)  # Add batch dim (1, ny)
        else:
            warnings.warn(f"Unexpected eflux_field dimensions: {eflux_field.shape}")
            diag["qspec"] = torch.zeros(batch_size, ny, device=phi_fft.device)

        return diag

    def compute_spectral_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        loss_type: str = "l1",
        eps: float = 1e-8,
    ) -> torch.Tensor:
        """Compute spectral loss with normalization options."""
        if loss_type == "l1":
            return F.l1_loss(pred, target)
        elif loss_type == "mse":
            return F.mse_loss(pred, target)
        elif loss_type == "normalized_l1" or loss_type == "normalized_mse":
            # GT normalization
            target_scale = torch.mean(torch.abs(target)) + eps
            normalized_pred = pred / target_scale
            normalized_target = target / target_scale
            if loss_type == "normalized_l1":
                return F.l1_loss(normalized_pred, normalized_target)
            else:  # loss_type == "normalized_mse"
                return F.mse_loss(normalized_pred, normalized_target)
        elif loss_type == "relative_l1":
            abs_diff = torch.abs(pred - target)
            target_magnitude = torch.abs(target) + eps
            relative_error = abs_diff / target_magnitude
            return torch.mean(relative_error)
        elif loss_type == "relative_mse":
            squared_diff = (pred - target) ** 2
            target_squared = target**2 + eps
            relative_squared_error = squared_diff / target_squared
            return torch.mean(relative_squared_error)
        elif loss_type == "log_l1":
            pred_log = torch.log(torch.abs(pred) + eps)
            target_log = torch.log(torch.abs(target) + eps)
            return F.l1_loss(pred_log, target_log)
        elif loss_type == "log_relative_l1":
            pred = pred / (pred.sum() + eps)
            target = target / (target.sum() + eps)
            pred_log = torch.log(torch.abs(pred) + eps)
            target_log = torch.log(torch.abs(target) + eps)
            return F.l1_loss(pred_log, target_log)
        elif loss_type == "log_mse":
            pred_log = torch.log(torch.abs(pred) + eps)
            target_log = torch.log(torch.abs(target) + eps)
            return F.mse_loss(pred_log, target_log)
        else:
            raise ValueError(f"Unknown spectral loss type: {loss_type}")

    def compute_spectral_losses(
        self,
        preds: Dict[str, torch.Tensor],
        tgts: Dict[str, torch.Tensor],
        geometry: Dict[str, torch.Tensor],
        use_normalization: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Compute spectral losses from predictions and targets.

        Both df tensors are reduced to 2 channels (4-channel inputs have
        channels (0, 2) and (1, 3) summed), cast to float32, and integrated
        with ``self.integrator_spec``. phi is taken from the inputs when
        present and from the integrals otherwise; it is then FFT-ed and turned
        into diagnostics. Losses produced:
            "kxspec", "kyspec", "qspec", "phi_zf": ``compute_spectral_loss``
                with the mode's spectral loss type.
            "qspec_monotonicity", "kyspec_monotonicity": penalty on increases
                in the predicted spectrum's tail beyond its peak.
            "mass": log-L1 between the summed pred and target df.

        Best effort: returns ``{}`` with a warning when ``self.ds`` is None or
        "df" is missing. Any exception in the main computation is converted to
        a warning, and whatever was computed before it is returned.

        NOTE(review): ``use_normalization`` is not read anywhere in this method.
        """
        spectral_losses = {}

        # Get the appropriate spectral loss type based on training/evaluation mode
        current_loss_types = self._get_current_loss_types()
        effective_spectral_loss_type = current_loss_types["spectral_loss_type"]

        # need ds to scale
        if self.ds is None:
            warnings.warn("ds parameter not set, skipping spectral losses")
            return spectral_losses

        # need df in preds and tgts to compute spectral losses
        if "df" not in preds or "df" not in tgts:
            warnings.warn("Missing df in predictions/targets, skipping spectral losses")
            return spectral_losses

        # calc integrals to get eflux fields and phi
        try:
            # eflux from integral computation
            # Handle zonal flow components
            if preds["df"].shape[1] == 4:
                # from (batch, 4, vpar, vmu, s, x, y) -> (batch, 2, vpar, vmu, s, x, y)
                preds_df = preds["df"][:, [0, 1]] + preds["df"][:, [2, 3]]
            elif preds["df"].shape[1] == 2:
                preds_df = preds["df"]
            else:
                raise ValueError(
                    f"Unexpected df shape: {preds['df'].shape}. Expected 2 or 4 channels."
                )
            # float32 to avoid cuFFT half-precision issues
            preds_df = preds_df.float()

            # Same for targets
            if tgts["df"].shape[1] == 4:
                tgts_df = tgts["df"][:, [0, 1]] + tgts["df"][:, [2, 3]]
            elif tgts["df"].shape[1] == 2:
                tgts_df = tgts["df"]
            else:
                raise ValueError(
                    f"Unexpected df shape: {tgts['df'].shape}. Expected 2 or 4 channels."
                )
            tgts_df = tgts_df.float()

            # calc integrals (only phi and eflux are used below; the particle
            # and third flux outputs are unpacked but unused)
            integrated_preds = self.integrator_spec(
                geometry, preds_df, preds.get("phi")
            )
            pred_phi, (pred_pflux, pred_eflux, pred_vflux) = integrated_preds
            integrated_tgts = self.integrator_spec(geometry, tgts_df, tgts.get("phi"))
            tgt_phi, (tgt_pflux, tgt_eflux, tgt_vflux) = integrated_tgts

            # Use provided phi if available, otherwise use computed phi from integrals
            pred_phi_for_fft = preds.get("phi", pred_phi)
            tgt_phi_for_fft = tgts.get("phi", tgt_phi)

            # to FFT domain
            pred_phi_fft = self.phi_fft(pred_phi_for_fft)
            tgt_phi_fft = self.phi_fft(tgt_phi_for_fft)

            # calc diagnostics
            pred_diag = self.diagnostics(pred_phi_fft, pred_eflux, self.ds)
            tgt_diag = self.diagnostics(tgt_phi_fft, tgt_eflux, self.ds)

            # calc basic spectral losses (using aggregated values)
            for k in ["kxspec", "kyspec", "qspec", "phi_zf"]:
                if k in pred_diag and k in tgt_diag:
                    spectral_losses[k] = self.compute_spectral_loss(
                        pred_diag[k], tgt_diag[k], effective_spectral_loss_type
                    )

            # calc monotonicity losses for certain spectra (using full spectra)
            # NOTE: the default "mean" aggregate sums each kx/ky spectrum down to
            # one value per sample, so the peak/tail analysis below needs the
            # full (aggregate="none") spectra. TODO: reuse these for the basic
            # losses above to avoid computing the diagnostics twice.
            pred_diag_full = self.diagnostics(
                pred_phi_fft, pred_eflux, self.ds, aggregate="none"
            )
            tgt_diag_full = self.diagnostics(
                tgt_phi_fft, tgt_eflux, self.ds, aggregate="none"
            )
            for k in ["qspec", "kyspec"]:
                monotonicity_key = f"{k}_monotonicity"
                try:
                    if k in pred_diag_full and k in tgt_diag_full:
                        # pred_diag_full[k] and tgt_diag_full[k] should be (batch, n_points)
                        pred_spec = torch.nan_to_num(
                            torch.log1p(pred_diag_full[k])
                        )  # (batch, n_points)
                        tgt_spec = torch.nan_to_num(
                            torch.log1p(tgt_diag_full[k])
                        )  # (batch, n_points)

                        # batched tensors - find peak for each batch
                        batch_monotonicity_losses = []
                        for b in range(pred_spec.shape[0]):
                            # find peak and compute tail monotonicity for this batch
                            # (the peak is located on the PREDICTION; the target
                            # tail is sliced at the same index)
                            peak_idx = torch.argmax(pred_spec[b]).item()
                            pred_tail = pred_spec[b, peak_idx:]
                            tgt_tail = tgt_spec[b, peak_idx:]

                            if len(pred_tail) > 1:
                                # finite differences
                                pred_diff = pred_tail[1:] - pred_tail[:-1]
                                tgt_diff = tgt_tail[1:] - tgt_tail[:-1]
                                # allow predicted increases up to the target's
                                # largest increase (never below 0); penalize
                                # only the excess
                                tol = torch.clamp(tgt_diff.max(), min=0.0)
                                monotonicity_loss = torch.mean(
                                    torch.clamp(pred_diff - tol, min=0.0)
                                )
                                # # isotonic loss
                                # tail_sorted, _ = torch.sort(pred_tail, descending=True)
                                # monotonicity_loss = F.l1_loss(pred_tail, tail_sorted)
                                batch_monotonicity_losses.append(monotonicity_loss)

                        if batch_monotonicity_losses:
                            spectral_losses[monotonicity_key] = torch.stack(
                                batch_monotonicity_losses
                            ).mean()
                        else:
                            # fallback to zero loss if no valid tails found
                            spectral_losses[monotonicity_key] = torch.tensor(
                                0.0, device=pred_diag_full[k].device
                            )
                except Exception as e:
                    # if monotonicity computation fails (e.g., due to NaN), set to zero
                    warnings.warn(
                        f"Failed to compute {monotonicity_key}: {e}, setting to 0.0"
                    )
                    spectral_losses[monotonicity_key] = torch.tensor(
                        0.0,
                        device=(
                            pred_diag_full[k].device
                            if k in pred_diag_full
                            else torch.device("cpu")
                        ),
                    )

            # Mass conservation loss - use the combined df (with zonal flow added)
            spectral_losses["mass"] = self.compute_spectral_loss(
                preds_df.sum(),
                tgts_df.sum(),
                "log_l1",  # always use log_l1 for mass conservation
            )

        except Exception as e:
            # best effort: keep whatever losses were computed before the failure
            warnings.warn(f"Failed to compute spectral losses: {e}")

        return spectral_losses

    def forward(
        self,
        preds: Dict[str, torch.Tensor],
        tgts: Dict[str, torch.Tensor],
        idx_data: Optional[Dict[str, torch.Tensor]] = None,
        geometry: Optional[Dict[str, torch.Tensor]] = None,
        compute_integrals: bool = True,
        progress_remaining: float = 1.0,
        separate_zf: bool = False,
        integral_loss_type: str = "mse",
    ):
        # Override loss types for evaluation to ensure consistency
        current_loss_types = self._get_current_loss_types()
        effective_integral_loss_type = current_loss_types["integral_loss_type"]

        losses = {}
        int_losses = {}
        int_losses_monitoring = {}

        if self.training:
            # update loss weights scheduler
            for key in self.schedulers.keys():
                if key in self.weights:
                    self.weights[key] = self.schedulers[key](progress_remaining)

        # NOTE: network predicts phi -> weight["phi_int"] = 0 (otherwise summed twice)
        # only compute integrals if requested by weights or in eval
        do_ints = not self.training and compute_integrals
        if sum([self.weights.get(k, 0.0) for k in self._int_losses]) > 0 or do_ints:
            int_losses, int_losses_monitoring, integrated = self.integral_loss(
                geometry, preds, tgts, idx_data, effective_integral_loss_type
            )

        # get VAE and VQ-VAE losses if needed
        vae_losses = {}
        vqvae_losses = {}
        if sum([self.weights.get(k, 0.0) for k in self._vae_losses]) > 0:
            vae_losses = self.compute_vae_loss(preds)
        if sum([self.weights.get(k, 0.0) for k in self._vqvae_losses]) > 0:
            vqvae_losses = self.compute_vqvae_loss(preds)

        # get spectral losses if needed
        spectral_losses = {}
        do_spectral = not self.training and compute_integrals
        if (
            sum([self.weights.get(k, 0.0) for k in self._spectral_losses]) > 0
            or do_spectral
        ):
            if geometry is not None:
                spectral_losses = self.compute_spectral_losses(preds, tgts, geometry)
            else:
                print(f"Geometry is None, cannot compute spectral losses")

        loss_keys = (
            [k for k, w in self.weights.items() if w > 0.0]
            if self.training
            else list(
                set(self.weights.keys())
                .union(set(int_losses.keys()))
                .union(set(vae_losses.keys()))
                .union(set(vqvae_losses.keys()))
                .union(set(spectral_losses.keys()))
            )
        )

        int_keys = [k for k in loss_keys if "int" in k]
        vae_keys = [k for k in loss_keys if k in self._vae_losses]
        vqvae_keys = [k for k in loss_keys if k in self._vqvae_losses]
        spectral_keys = [k for k in loss_keys if k in self._spectral_losses]

        data_keys = list(
            set(loss_keys)
            - set(int_keys)
            - set(vae_keys)
            - set(vqvae_keys)
            - set(spectral_keys)
        )

        if not all([k in preds for k in data_keys]):
            # warnings.warn("Prediction - DATA loss weight key mismatch.")
            missing_keys = [k for k in data_keys if k not in preds]
            for k in missing_keys:
                preds[k] = torch.zeros_like(tgts[k]).to(tgts[k].device)

        # compute losses
        for k in data_keys:
            if preds[k].shape != tgts[k].shape and k == "phi":
                preds[k] = preds[k].unsqueeze(0)
            if k == "df" and separate_zf:
                # Handle separate zonal flow components
                zf_loss = self.compute_data_loss(preds[k][:, :2], tgts[k][:, :2])
                other_loss = self.compute_data_loss(preds[k][:, 2:], tgts[k][:, 2:])
                losses[k] = zf_loss + other_loss
            else:
                losses[k] = self.compute_data_loss(preds[k], tgts[k])

        for k in int_losses:
            losses[k] = int_losses[k]
        for k in vae_losses:
            losses[k] = vae_losses[k]
        for k in vqvae_losses:
            losses[k] = vqvae_losses[k]
        for k in spectral_losses:
            losses[k] = spectral_losses[k]

        # Always compute regular MSE for monitoring (even when using different loss for training)
        mse_losses = {}
        if self.training:
            # no gradient for monitoring losses
            with torch.no_grad():
                for k in data_keys:
                    if k in preds and k in tgts:
                        if k == "df" and separate_zf:
                            # Handle separate zonal flow components for MSE
                            zf_mse = F.mse_loss(preds[k][:, :2], tgts[k][:, :2])
                            other_mse = F.mse_loss(preds[k][:, 2:], tgts[k][:, 2:])
                            mse_losses[f"{k}_mse"] = zf_mse + other_mse
                        else:
                            mse_losses[f"{k}_mse"] = F.mse_loss(preds[k], tgts[k])
                mse_loss = sum(mse_losses.values())

        if self.training:
            # Update EMA scales for specified losses
            for loss_name, loss_value in losses.items():
                if loss_name in self.ema_normalization_loss:
                    self._update_ema_loss_scale(loss_name, loss_value)

            # Apply EMA normalization before reweighting
            losses = self._apply_ema_normalization(losses)

            # reweight and accumulate - only use keys that actually exist in losses
            loss_keys_available = [k for k in loss_keys if k in losses]
            loss = sum([self.weights[k] * losses[k] for k in loss_keys_available])
            # filter active losses and add MSE monitoring metrics
            filtered_losses = {
                k: losses[k] for k, w in self.weights.items() if w > 0.0 and k in losses
            }
            filtered_losses.update({"total_mse": mse_loss})
            # add integral monitoring losses
            filtered_losses.update(int_losses_monitoring)
            # add EMA statistics for monitoring
            if self.ema_normalization_loss:
                filtered_losses.update(self.get_ema_statistics())

            return loss, filtered_losses
        else:
            # no reweight in validation - only use keys that actually exist in losses
            loss_keys_available = [k for k in loss_keys if k in losses]
            loss = sum([losses[k] for k in loss_keys_available])
            # combine all losses for evaluation
            all_losses = {**losses, **int_losses_monitoring}

            # Add loss statistics for debugging scale differences
            loss_stats = self.get_loss_statistics(all_losses)

            return loss, all_losses, integrated, loss_stats

    def normalize_loss_scales(
        self, losses: Dict[str, torch.Tensor], alpha: float = 0.01
    ) -> Dict[str, torch.Tensor]:
        """
        Normalize loss scales using exponential moving average for stable gradient balancing.

        Args:
            losses: Dictionary of loss tensors
            alpha: EMA decay factor for loss scale tracking

        Returns:
            Dictionary of normalized loss tensors
        """
        normalized_losses = {}

        for key, loss in losses.items():
            if key not in self.active_losses:
                normalized_losses[key] = loss
                continue

            # Track loss scale with EMA
            loss_magnitude = float(loss.detach().abs())
            if key not in self.loss_normalizer:
                self.loss_normalizer[key] = loss_magnitude
            else:
                self.loss_normalizer[key] = (1 - alpha) * self.loss_normalizer[
                    key
                ] + alpha * loss_magnitude

            # Normalize to unit scale (average magnitude = 1.0)
            scale_factor = 1.0 / (self.loss_normalizer[key] + 1e-8)
            normalized_losses[key] = loss * scale_factor

        return normalized_losses

    @property
    def active_losses(self):
        return [k for k in self.weights if self.weights[k] > 0.0]

    def __len__(self):
        return len(self.all_losses)


def _wide_min_norm_solution(
    units: torch.Tensor, weights: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
    """
    Compute minimal-norm solution of units x = weights for the
    underdetermined case (m << N) using Gram matrix inversion.
    units: (m, N)
    weights: (m,)
    Returns: x (N,)
    """
    # Filter out any (near) zero rows to avoid singular G
    row_norms = units.norm(dim=1)
    keep = row_norms > 0
    if keep.sum() == 0:
        return torch.zeros(units.shape[1], device=units.device, dtype=units.dtype)
    U = units[keep]  # (m', N)
    w = weights[keep]  # (m',)
    # Gram matrix
    G = U @ U.t()  # (m', m')
    # add regularization (for numerical stability)
    reg = eps * G.diag().mean()
    G = G + reg * torch.eye(G.size(0), device=G.device, dtype=G.dtype)
    # Solve G a = w
    a = torch.linalg.solve(G, w)  # (m',)
    # Compose parameter-space solution
    x = a @ U  # (N,)
    return x


class GradientBalancer(nn.Module):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        mode: str,
        scaler: torch.amp.GradScaler,
        clip_grad: bool = True,
        clip_to: float = 1.0,
        n_tasks: Optional[int] = None,
    ):
        super().__init__()

        self.optimizer = optimizer
        self.mode = mode
        self.clip_grad = clip_grad
        self.scaler = scaler
        self.clip_to = clip_to
        if mode in [None, "none"]:
            pass
        # conflict free gradnorm
        if mode in ["full"]:
            from conflictfree.grad_operator import ConFIGOperator
            from conflictfree.momentum_operator import PseudoMomentumOperator
            from conflictfree.utils import OrderedSliceSelector
            from typing import Union, Sequence, Optional

            # Override ConFIG_update function to handle wide matrices
            def ConFIG_update(
                grads: Union[torch.Tensor, Sequence[torch.Tensor]],
                weight_model=None,
                length_model=None,
                use_least_square: bool = True,
                losses: Optional[Sequence] = None,
            ) -> torch.Tensor:
                from conflictfree.weight_model import EqualWeight
                from conflictfree.length_model import ProjectionLength

                if weight_model is None:
                    weight_model = EqualWeight()
                if length_model is None:
                    length_model = ProjectionLength()

                if not isinstance(grads, torch.Tensor):
                    grads = torch.stack(grads)
                with torch.no_grad():
                    weights = weight_model.get_weights(
                        gradients=grads, losses=losses, device=grads.device
                    )
                    # Normalize each gradient vector to unit (or zero) row
                    units = torch.nan_to_num(
                        grads / (grads.norm(dim=1).unsqueeze(1)),
                        nan=0.0,
                        posinf=0.0,
                        neginf=0.0,
                    )

                    # If the model is extremely wide (many weights), skip lstsq directly
                    m, n = units.shape
                    # wide_case = n > 100000 and m <= 16  # heuristic threshold
                    wide_case = False
                    use_least_square = True

                    if use_least_square and not wide_case:
                        try:
                            best_direction = torch.linalg.lstsq(units, weights).solution
                            # best_direction = torch.linalg.pinv(units) @ weights  # pinv returns (N, m), result (N,)
                            # print first time we use lstsq that we are using it
                            if not hasattr(self, "_lstsq_used"):
                                print(
                                    f"Using lstsq for conflict-free gradient balancing with {m} tasks and {n} parameters."
                                )
                                self._lstsq_used = True
                        except Exception as e:
                            # fallback to wide minimal norm solution
                            best_direction = _wide_min_norm_solution(units, weights)
                            if not hasattr(self, "_lstsq_failed"):
                                print(
                                    f"lstsq failed ({e}), falling back to wide minimal norm solution."
                                )
                                self._lstsq_failed = True
                    else:
                        best_direction = _wide_min_norm_solution(units, weights)

                    return length_model.rescale_length(
                        target_vector=best_direction,
                        gradients=grads,
                        losses=losses,
                    )

            # monkey patch the ConFIG_update function
            import conflictfree.grad_operator

            conflictfree.grad_operator.ConFIG_update = ConFIG_update
            self.operator = ConFIGOperator()

        elif mode == "pseudo":
            from conflictfree.momentum_operator import PseudoMomentumOperator
            from conflictfree.utils import OrderedSliceSelector

            assert n_tasks is not None, "n_tasks must be specified for pseudo mode"

            print(f"Using Pseudo-Momentum gradient balancing with {n_tasks} tasks.")
            self.operator = PseudoMomentumOperator(n_tasks)
            self.loss_selector = OrderedSliceSelector()
            # Debug info
            print(f"  PseudoMomentumOperator initialized with {n_tasks} tasks")
            print(f"  Expected to cycle through task indices 0 to {n_tasks-1}")

    def forward(
        self, model: nn.Module, weighted_loss: torch.Tensor, losses: List[torch.Tensor]
    ):
        """Balances multitask gradients with conflict-free IG."""

        if self.mode in [None, "none"]:
            self.optimizer.zero_grad(set_to_none=True)
            self.scaler.scale(weighted_loss).backward()
        elif self.mode == "pseudo":
            from conflictfree.utils import get_gradient_vector

            self.optimizer.zero_grad(set_to_none=True)
            idx, loss_i = self.loss_selector.select(1, losses)

            # Debug info for first few steps
            if not hasattr(self, "_debug_step"):
                self._debug_step = 0
            if self._debug_step < 5:
                print(
                    f"Pseudo momentum step {self._debug_step}: selected idx={idx}, loss_value={loss_i.item():.6f}, total_losses={len(losses)}"
                )
                self._debug_step += 1

            self.scaler.scale(loss_i).backward()
            self.operator.update_gradient(model, idx, get_gradient_vector(model))
        elif self.mode == "full":
            from conflictfree.utils import get_gradient_vector

            grads = []
            for loss_i in losses:
                self.optimizer.zero_grad(set_to_none=True)
                # retain graph for multiple backward passes
                self.scaler.scale(loss_i).backward(retain_graph=True)
                grads.append(get_gradient_vector(model, none_grad_mode="zero"))  # noqa
            # apply conflict-free gradient directions
            self.operator.update_gradient(model, grads)

        # clipping
        if self.clip_grad:
            self.scaler.unscale_(self.optimizer)
            clip_grad_norm_(model.parameters(), self.clip_to)
        # gradient step
        self.scaler.step(self.optimizer)
        self.scaler.update()
        return model
